import os
import pandas as pd
import numpy as np
import pickle
import glob

def _tagged_value(dir_name, tag):
    """Return the text after ``tag=`` in the first ``_``-separated part that starts with it, else None."""
    prefix = f"{tag}="
    for part in dir_name.split('_'):
        if part.startswith(prefix):
            return part[len(prefix):]
    return None


def _required_int(raw, label, dir_name):
    """Convert ``raw`` to int; print a warning and return None when it is missing or not an integer."""
    if not raw:
        print(f"Warning: Could not find {label} value in '{dir_name}'")
        return None
    try:
        return int(raw)
    except ValueError:
        print(f"Warning: Could not parse {label} value from '{dir_name}'")
        return None


def _optional_int(raw, default, label, dir_name):
    """Convert ``raw`` (with an optional leading '=') to int, warning and returning ``default`` on failure."""
    text = raw[1:] if raw.startswith('=') else raw
    try:
        return int(text)
    except ValueError:
        print(f"Warning: Could not parse {label} value from '{dir_name}'; using default {default}")
        return default


def _parse_dir_params(dir_name):
    """Parse xi, Q, forced and max_conv_idx from a result directory name; None if xi or Q is unusable."""
    params = {}

    # xi: either a "xi=<n>" part, or the text between a leading "xi" and the first "Q".
    if 'xi=' in dir_name:
        raw_xi = _tagged_value(dir_name, 'xi')
    else:
        head = dir_name.split('Q')[0]
        raw_xi = head[2:] if head.startswith('xi') else None
    xi = _required_int(raw_xi, 'xi', dir_name)
    if xi is None:
        return None
    params['xi'] = xi

    # Q: either a "Q=<n>" part, or the text between the first "Q" and "forced".
    if 'Q=' in dir_name:
        raw_q = _tagged_value(dir_name, 'Q')
    else:
        head = dir_name.split('forced')[0]
        q_idx = head.find('Q')
        raw_q = head[q_idx + 1:] if q_idx >= 0 else None
    q = _required_int(raw_q, 'Q', dir_name)
    if q is None:
        return None
    params['Q'] = q

    # forced / maxconvidx are optional; both "forced5" and "forced=5" spellings are accepted.
    if 'forced' in dir_name:
        raw_forced = dir_name.split('forced')[1].split('_')[0]
        params['forced'] = _optional_int(raw_forced, 0, 'forced', dir_name)
    if 'maxconvidx' in dir_name:
        raw_conv = dir_name.split('maxconvidx')[1]
        params['max_conv_idx'] = _optional_int(raw_conv, 100, 'maxconvidx', dir_name)

    return params


def load_results(base_dir="./results"):
    """Load every ``simplified_caching.pkl`` under ``base_dir`` into one dictionary.

    Directories matching ``xi*Q*forced*_maxconvidx*`` are scanned in sorted
    order. Parameters are parsed from each directory name (both the compact
    ``xi1000Q200forced5_maxconvidx100`` and the ``xi=1000_Q=200_...`` spellings).

    Args:
        base_dir: Directory containing the per-run result directories.

    Returns:
        Dict keyed by ``"xi=<xi>_Q=<Q>"``. Each value is a dict with ``'xi'``,
        ``'Q'``, ``'data'`` (the unpickled object) and, when present in the
        name, ``'forced'`` and ``'max_conv_idx'``. Directories that cannot be
        parsed, have no pickle file, fail to load, or lack one of the required
        policy keys are skipped with a printed message. If two directories
        produce the same key, the later one wins and a warning is printed.
    """
    results_data = {}
    required_keys = ['tail_lru_none', 'vanilla_lru', 'thre_lru']

    result_dirs = sorted(glob.glob(f"{base_dir}/xi*Q*forced*_maxconvidx*"))
    print(f"Found {len(result_dirs)} result directories")

    for result_dir in result_dirs:
        dir_name = os.path.basename(result_dir)
        params = _parse_dir_params(dir_name)
        if params is None:
            continue

        pickle_file = os.path.join(result_dir, 'simplified_caching.pkl')
        if not os.path.exists(pickle_file):
            print(f"Warning: No pickle file found at {pickle_file}")
            continue

        # Best-effort loading: one unreadable or malformed file must not abort the scan.
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at result directories you produced yourself.
            with open(pickle_file, 'rb') as f:
                data = pickle.load(f)

            if not all(key in data for key in required_keys):
                print(f"Warning: Missing required keys in data from {pickle_file}")
                print(f"Found keys: {list(data.keys())}")
                continue

            any_policy = data[required_keys[0]]
            print(f"Available percentiles in {dir_name}: {list(any_policy.keys())}")
        except Exception as e:
            print(f"Error loading {pickle_file}: {e}")
            continue

        params['data'] = data
        key = f"xi={params['xi']}_Q={params['Q']}"
        if key in results_data:
            # The key ignores forced/max_conv_idx, so make the collision visible.
            print(f"Warning: '{dir_name}' overwrites previously loaded results for {key}")
        results_data[key] = params
        print(f"Loaded results for {key}")

    if not results_data:
        print("No valid results were loaded. Check the results directory and file formats.")
    else:
        print(f"Successfully loaded data for {len(results_data)} parameter combinations")

    return results_data

def find_nearest_capacity_idx(actual_capacities, target_capacity):
    """Return the position of the entry in ``actual_capacities`` closest to ``target_capacity``.

    Returns None when ``actual_capacities`` is empty. When several entries are
    equally close, the earliest one wins.
    """
    if not actual_capacities:
        return None

    best_idx = 0
    best_gap = abs(actual_capacities[0] - target_capacity)

    # Single pass; the strict comparison keeps the first of any equally-near entries.
    for idx, candidate in enumerate(actual_capacities):
        gap = abs(candidate - target_capacity)
        if gap < best_gap:
            best_idx = idx
            best_gap = gap

    return best_idx

def create_ratio_tables(results_data, percentiles=[90, 95], capacities=[3000, 5000, 7000, 9000]):
    """Create tables of T-LRU/LRU and T-LRU/threshold-LRU ratios.

    Args:
        results_data: Mapping whose values are dicts with an ``'xi'`` entry and
            a ``'data'`` entry; ``data[policy][percentile]`` is indexed by
            capacity position for the policies ``'tail_lru_none'``,
            ``'vanilla_lru'`` and ``'thre_lru'``.
        percentiles: Percentile keys to build tables for (never mutated here).
        capacities: Capacities used as table columns (never mutated here).

    Returns:
        ``(tlru_lru_ratio, tlru_thre_ratio)``: two dicts mapping each
        percentile to a DataFrame with unique xi values (descending) as index
        and ``capacities`` as columns. Cells that cannot be computed stay NaN.
        If several results share an xi, later ones overwrite earlier ones.
    """
    # De-duplicate xi: several Q values can share one xi, and duplicate index
    # labels would make later ``.loc[xi, capacity]`` reads return a Series.
    xi_values = sorted({params['xi'] for params in results_data.values()}, reverse=True)

    # The pickles do not store the capacity grid, so it is inferred from the
    # number of entries in the first result (all runs are assumed to match).
    standard_capacities = [1000, 2000, 4000, 6000, 8000, 10000]
    actual_capacities = None
    for params in results_data.values():
        data = params['data']
        first_policy = list(data.keys())[0]
        first_percentile = list(data[first_policy].keys())[0]
        num_capacities = len(data[first_policy][first_percentile])
        print(f"Found {num_capacities} capacity values in the data")

        # Use the standard capacities from simplified_caching.py if the count matches
        if num_capacities == len(standard_capacities):
            actual_capacities = standard_capacities
            print(f"Using standard capacities: {actual_capacities}")
        else:
            # Best guess - linear spacing
            actual_capacities = list(range(1000, 1000 * (num_capacities + 1), 1000))
            print(f"Using estimated capacities: {actual_capacities}")
        break  # only the first result is inspected

    if actual_capacities is None:
        print("Warning: Could not determine actual capacities, using defaults")
        actual_capacities = standard_capacities

    # The requested-capacity -> data-position mapping does not depend on the
    # result or percentile, so compute it once.
    nearest_idx = {
        capacity: find_nearest_capacity_idx(actual_capacities, capacity)
        for capacity in capacities
    }

    tlru_lru_ratio = {}
    tlru_thre_ratio = {}
    for p in percentiles:
        tlru_lru_ratio[p] = pd.DataFrame(index=xi_values, columns=capacities)
        tlru_thre_ratio[p] = pd.DataFrame(index=xi_values, columns=capacities)

        for params in results_data.values():
            xi = params['xi']
            data = params['data']

            try:
                tlru_data = data['tail_lru_none'][p]
                lru_data = data['vanilla_lru'][p]
                thre_data = data['thre_lru'][p]
                data_len = len(tlru_data)
            except (KeyError, IndexError, TypeError) as e:
                print(f"Error processing data for xi={xi}, percentile={p}: {e}")
                continue

            for capacity in capacities:
                idx = nearest_idx[capacity]
                if idx is None or idx >= data_len:
                    print(f"Warning: Capacity index {idx} out of range for data length {data_len}")
                    continue

                # Each cell is computed on its own so that one bad baseline
                # (e.g. a zero) only blanks that cell instead of the whole row.
                for table, baseline in ((tlru_lru_ratio[p], lru_data),
                                        (tlru_thre_ratio[p], thre_data)):
                    try:
                        table.loc[xi, capacity] = tlru_data[idx] / baseline[idx]
                    except (IndexError, KeyError, TypeError, ZeroDivisionError) as e:
                        print(f"Error processing data for xi={xi}, percentile={p}: {e}")

    return tlru_lru_ratio, tlru_thre_ratio

def create_improvement_tables(results_data, percentiles=[90, 95], capacities=[3000, 5000, 7000, 9000]):
    """Create tables of relative improvement: (LRU - T-LRU)/LRU and (threshold-LRU - T-LRU)/threshold-LRU.

    Args:
        results_data: Mapping whose values are dicts with an ``'xi'`` entry and
            a ``'data'`` entry; ``data[policy][percentile]`` is indexed by
            capacity position for the policies ``'tail_lru_none'``,
            ``'vanilla_lru'`` and ``'thre_lru'``.
        percentiles: Percentile keys to build tables for (never mutated here).
        capacities: Capacities used as table columns (never mutated here).

    Returns:
        ``(tlru_vs_lru_improvement, tlru_vs_thre_improvement)``: two dicts
        mapping each percentile to a DataFrame with unique xi values
        (descending) as index and ``capacities`` as columns. Positive values
        mean T-LRU is lower than the baseline. Cells that cannot be computed
        stay NaN. If several results share an xi, later ones overwrite
        earlier ones.
    """
    # De-duplicate xi: several Q values can share one xi, and duplicate index
    # labels would make later ``.loc[xi, capacity]`` reads return a Series.
    xi_values = sorted({params['xi'] for params in results_data.values()}, reverse=True)

    # The pickles do not store the capacity grid, so it is inferred from the
    # number of entries in the first result (all runs are assumed to match).
    standard_capacities = [1000, 2000, 4000, 6000, 8000, 10000]
    actual_capacities = None
    for params in results_data.values():
        data = params['data']
        first_policy = list(data.keys())[0]
        first_percentile = list(data[first_policy].keys())[0]
        num_capacities = len(data[first_policy][first_percentile])

        # Use the standard capacities from simplified_caching.py if the count matches
        if num_capacities == len(standard_capacities):
            actual_capacities = standard_capacities
        else:
            # Best guess - linear spacing
            actual_capacities = list(range(1000, 1000 * (num_capacities + 1), 1000))
        break  # only the first result is inspected

    if actual_capacities is None:
        print("Warning: Could not determine actual capacities, using defaults")
        actual_capacities = standard_capacities

    # The requested-capacity -> data-position mapping does not depend on the
    # result or percentile, so compute it once.
    nearest_idx = {
        capacity: find_nearest_capacity_idx(actual_capacities, capacity)
        for capacity in capacities
    }

    tlru_vs_lru_improvement = {}
    tlru_vs_thre_improvement = {}
    for p in percentiles:
        tlru_vs_lru_improvement[p] = pd.DataFrame(index=xi_values, columns=capacities)
        tlru_vs_thre_improvement[p] = pd.DataFrame(index=xi_values, columns=capacities)

        for params in results_data.values():
            xi = params['xi']
            data = params['data']

            try:
                tlru_data = data['tail_lru_none'][p]
                lru_data = data['vanilla_lru'][p]
                thre_data = data['thre_lru'][p]
                data_len = len(tlru_data)
            except (KeyError, IndexError, TypeError) as e:
                print(f"Error processing data for xi={xi}, percentile={p}: {e}")
                continue

            for capacity in capacities:
                idx = nearest_idx[capacity]
                if idx is None or idx >= data_len:
                    print(f"Warning: Capacity index {idx} out of range for data length {data_len}")
                    continue

                # Lower is better for the measured quantity, so improvement is
                # (baseline - tlru) / baseline. Each cell is computed on its
                # own so one bad baseline (e.g. a zero) only blanks that cell.
                for table, baseline in ((tlru_vs_lru_improvement[p], lru_data),
                                        (tlru_vs_thre_improvement[p], thre_data)):
                    try:
                        table.loc[xi, capacity] = (baseline[idx] - tlru_data[idx]) / baseline[idx]
                    except (IndexError, KeyError, TypeError, ZeroDivisionError) as e:
                        print(f"Error processing data for xi={xi}, percentile={p}: {e}")

    return tlru_vs_lru_improvement, tlru_vs_thre_improvement

def create_example_style_table(tlru_lru_ratio, percentiles=[90, 95]):
    """Pivot per-percentile tables into one capacity-by-(xi, percentile) table.

    ``tlru_lru_ratio`` maps each percentile to a DataFrame indexed by xi with
    capacities as columns. The result has capacities as its index (named
    "Capacity") and two-level columns ``("ξ = <xi>", "p<percentile>")`` with
    xi ascending. The layout is taken from the first percentile's table; any
    cell that cannot be read is reported and stored as NaN.
    """
    layout = tlru_lru_ratio[percentiles[0]]
    xi_values = sorted(layout.index.tolist())
    capacities = layout.columns.tolist()

    # Two-level header: one group per xi, one sub-column per percentile.
    header = [(f"ξ = {xi}", f"p{p}") for xi in xi_values for p in percentiles]
    df = pd.DataFrame(
        index=capacities,
        columns=pd.MultiIndex.from_tuples(header, names=[None, None]),
    )
    df.index.name = "Capacity"

    for capacity in capacities:
        for xi in xi_values:
            for p in percentiles:
                column = (f"ξ = {xi}", f"p{p}")
                try:
                    cell = tlru_lru_ratio[p].loc[xi, capacity]
                    # Non-numeric entries become NaN rather than raising.
                    df.loc[capacity, column] = pd.to_numeric(cell, errors='coerce')
                except (KeyError, ValueError, TypeError) as e:
                    print(f"Warning: Could not get value for xi={xi}, p={p}, capacity={capacity}: {e}")
                    df.loc[capacity, column] = np.nan

    return df

def create_latex_table(df, caption, label, color_scale=True):
    """
    Create a LaTeX table in the requested format.

    Columns are grouped by xi (ascending by the number parsed from labels such
    as "ξ = 14000"; labels that do not parse go last) and, within each group,
    by percentile label in numeric order. Missing or unreadable values are
    emitted as empty cells so every row keeps the same number of columns.

    Args:
        df: DataFrame with capacities as index, and MultiIndex columns with (xi, percentile)
        caption: Table caption
        label: Table label
        color_scale: Whether to apply color scale to cells (greener for lower values)

    Returns:
        String with LaTeX table
    """
    # Pair every xi column label with its numeric value for sorting/display.
    xi_pairs = []
    for xi_label in {col[0] for col in df.columns}:
        try:
            # Extract just the number from "ξ = 14000"
            xi_num = int(xi_label.split("=")[1].strip())
        except (ValueError, IndexError):
            # Use a large value for non-numeric xi to put at the end
            xi_num = float('inf')
        xi_pairs.append((xi_label, xi_num))
    xi_pairs.sort(key=lambda pair: (pair[1], pair[0]))

    def percentile_key(p_label):
        # Sort "p90" < "p100" numerically; unparseable labels go last.
        text = str(p_label)
        try:
            return (0, float(text.lstrip("p")), text)
        except ValueError:
            return (1, 0.0, text)

    percentiles = sorted({col[1] for col in df.columns}, key=percentile_key)

    # Upper bounds (exclusive) and the cell color used below each bound.
    color_steps = (
        (0.85, "57bb8a"),  # Dark green
        (0.9, "85cdaa"),   # Medium green
        (0.92, "96d4b6"),  # Light medium green
        (0.94, "b8e2ce"),  # Light green
        (0.96, "c9e9d9"),  # Very light green
        (0.98, "dcf1e6"),  # Extremely light green
    )
    fallback_color = "f2faf6"  # Almost white green

    def format_cell(capacity, column):
        # Always returns a cell (possibly blank) so columns never shift.
        try:
            value = df.loc[capacity, column]
            if pd.isna(value):
                return " "
            text = f"{value:.3f}"
            if not color_scale:
                return f"{text} "
            color = next(
                (hexcode for limit, hexcode in color_steps if value < limit),
                fallback_color,
            )
            return f"\\cellcolor[HTML]{{{color}}}{text} "
        except (KeyError, TypeError, ValueError):
            return " "

    # 1 column for capacity + len(percentiles) for each xi value
    num_percentiles = len(percentiles)
    num_cols = 1 + len(xi_pairs) * num_percentiles

    parts = [
        "\\begin{table}[!htp]\\centering\n",
        f"\\caption{{{caption}}}\n",
        "\\scriptsize\n",
        f"\\begin{{tabular}}{{{'r' * num_cols}}}\\toprule\n",
    ]

    # Header row for xi values (leading empty cell sits above "Capacity")
    xi_header = [""] + [
        f"\\multicolumn{{{num_percentiles}}}{{c}}{{$\\xi$ = {xi_num}}} "
        for _, xi_num in xi_pairs
    ]
    parts.append("&".join(xi_header) + " \\\\\\cmidrule{1-" + str(num_cols) + "}\n")

    # Header row for percentiles
    pct_header = ["Capacity "] + [f"{p} " for _ in xi_pairs for p in percentiles]
    parts.append("&".join(pct_header) + " \\\\\\midrule\n")

    # One row per capacity
    for capacity in df.index:
        cells = [f"{capacity} "]
        cells.extend(
            format_cell(capacity, (xi_label, p))
            for xi_label, _ in xi_pairs
            for p in percentiles
        )
        parts.append("&".join(cells) + " \\\\\n")

    parts.append("\\bottomrule\n")
    parts.append(f"\\end{{tabular}} \\label{{{label}}}\n")
    parts.append("\\end{table}")

    return "".join(parts)

def create_latex_improvement_table(df, caption, label):
    """
    Create a LaTeX table for improvement values (with different color scale).

    Values are fractions and are displayed as percentages with one decimal.
    Columns are grouped by xi (ascending by the number parsed from labels such
    as "ξ = 14000"; labels that do not parse go last) and, within each group,
    by percentile label in numeric order. Missing or unreadable values are
    emitted as empty cells so every row keeps the same number of columns.

    Args:
        df: DataFrame with capacities as index, and MultiIndex columns with (xi, percentile)
        caption: Table caption
        label: Table label

    Returns:
        String with LaTeX table
    """
    # Pair every xi column label with its numeric value for sorting/display.
    xi_pairs = []
    for xi_label in {col[0] for col in df.columns}:
        try:
            # Extract just the number from "ξ = 14000"
            xi_num = int(xi_label.split("=")[1].strip())
        except (ValueError, IndexError):
            # Use a large value for non-numeric xi to put at the end
            xi_num = float('inf')
        xi_pairs.append((xi_label, xi_num))
    xi_pairs.sort(key=lambda pair: (pair[1], pair[0]))

    def percentile_key(p_label):
        # Sort "p90" < "p100" numerically; unparseable labels go last.
        text = str(p_label)
        try:
            return (0, float(text.lstrip("p")), text)
        except ValueError:
            return (1, 0.0, text)

    percentiles = sorted({col[1] for col in df.columns}, key=percentile_key)

    # Lower bounds (exclusive, in percent) and the color used above each bound.
    color_steps = (
        (20, "57bb8a"),  # Dark green (>20% improvement)
        (15, "85cdaa"),  # Medium green (15-20% improvement)
        (10, "96d4b6"),  # Light medium green (10-15% improvement)
        (5, "b8e2ce"),   # Light green (5-10% improvement)
        (2, "c9e9d9"),   # Very light green (2-5% improvement)
        (0, "dcf1e6"),   # Extremely light green (0-2% improvement)
    )
    no_change_color = "ffffff"  # White (no improvement)
    worse_color = "ffd6d6"      # Light red (negative improvement - worse)

    def format_cell(capacity, column):
        # Always returns a cell (possibly blank) so columns never shift.
        try:
            value = df.loc[capacity, column]
            if pd.isna(value):
                return " "
            pct_value = value * 100  # fraction -> percentage for display
            color = next(
                (hexcode for limit, hexcode in color_steps if pct_value > limit),
                None,
            )
            if color is None:
                color = no_change_color if pct_value == 0 else worse_color
            return f"\\cellcolor[HTML]{{{color}}}{pct_value:.1f}\\% "
        except (KeyError, TypeError, ValueError):
            return " "

    # 1 column for capacity + len(percentiles) for each xi value
    num_percentiles = len(percentiles)
    num_cols = 1 + len(xi_pairs) * num_percentiles

    parts = [
        "\\begin{table}[!htp]\\centering\n",
        f"\\caption{{{caption}}}\n",
        "\\scriptsize\n",
        f"\\begin{{tabular}}{{{'r' * num_cols}}}\\toprule\n",
    ]

    # Header row for xi values (leading empty cell sits above "Capacity")
    xi_header = [""] + [
        f"\\multicolumn{{{num_percentiles}}}{{c}}{{$\\xi$ = {xi_num}}} "
        for _, xi_num in xi_pairs
    ]
    parts.append("&".join(xi_header) + " \\\\\\cmidrule{1-" + str(num_cols) + "}\n")

    # Header row for percentiles
    pct_header = ["Capacity "] + [f"{p} " for _ in xi_pairs for p in percentiles]
    parts.append("&".join(pct_header) + " \\\\\\midrule\n")

    # One row per capacity
    for capacity in df.index:
        cells = [f"{capacity} "]
        cells.extend(
            format_cell(capacity, (xi_label, p))
            for xi_label, _ in xi_pairs
            for p in percentiles
        )
        parts.append("&".join(cells) + " \\\\\n")

    parts.append("\\bottomrule\n")
    parts.append(f"\\end{{tabular}} \\label{{{label}}}\n")
    parts.append("\\end{table}")

    return "".join(parts)

def main():
    """Load the results, build ratio and improvement tables, and write four LaTeX files.

    Output goes to the ``latex_tables`` directory (created if missing). The
    requested percentiles are restricted to those present in the loaded data;
    if none of them is present, the numeric-looking percentile keys found in
    the data are used instead. Returns early when nothing could be loaded or
    no usable percentile exists.
    """
    results_data = load_results()

    if not results_data:
        print("No result files found!")
        return

    # Define parameters for tables
    percentiles = [90, 95]
    capacities = [1000, 2000, 4000, 6000, 8000, 10000]

    # Collect the percentile keys present in any loaded result
    all_available_percentiles = set()
    for params in results_data.values():
        all_available_percentiles.update(params['data']['tail_lru_none'].keys())

    print(f"Available percentiles in the data: {sorted(list(all_available_percentiles))}")

    # Filter percentiles to those available in the data
    available_percentiles = [p for p in percentiles if p in all_available_percentiles]
    if not available_percentiles:
        print(f"Error: None of the requested percentiles {percentiles} are available in the data")
        # Keep the keys exactly as stored (no int() cast): they are used for
        # dictionary lookups later, so '90' or 99.9 must stay '90' or 99.9.
        available_percentiles = sorted(
            (p for p in all_available_percentiles
             if isinstance(p, (int, float)) or str(p).isdigit()),
            key=float,
        )
        if available_percentiles:
            print(f"Using available percentiles instead: {available_percentiles}")
            percentiles = available_percentiles
        else:
            print("No valid percentiles found. Exiting.")
            return
    elif len(available_percentiles) < len(percentiles):
        print(f"Warning: Only using available percentiles: {available_percentiles}")
        percentiles = available_percentiles

    # Create ratio tables
    print("Creating ratio tables...")
    tlru_lru_ratio, tlru_thre_ratio = create_ratio_tables(results_data, percentiles, capacities)

    # Create improvement tables
    print("Creating improvement tables...")
    tlru_vs_lru_improvement, tlru_vs_thre_improvement = create_improvement_tables(results_data, percentiles, capacities)

    # Create output directory
    os.makedirs("latex_tables", exist_ok=True)

    # (source tables, LaTeX renderer, caption, label, output path)
    table_specs = [
        (
            tlru_lru_ratio,
            create_latex_table,
            "Relative latency of T-LRU compared to LRU with various $\\xi$",
            "tab:relative-lat-tlru-lru",
            "latex_tables/tlru_lru_ratio_latex.tex",
        ),
        (
            tlru_thre_ratio,
            create_latex_table,
            "Relative latency of T-LRU compared to Threshold-LRU with various $\\xi$",
            "tab:relative-lat-tlru-thre",
            "latex_tables/tlru_thre_ratio_latex.tex",
        ),
        (
            tlru_vs_lru_improvement,
            create_latex_improvement_table,
            "Relative latency improvement of T-LRU over LRU with various $\\xi$",
            "tab:improvement-tlru-lru",
            "latex_tables/tlru_lru_improvement_latex.tex",
        ),
        (
            tlru_vs_thre_improvement,
            create_latex_improvement_table,
            "Relative latency improvement of T-LRU over Threshold-LRU with various $\\xi$",
            "tab:improvement-tlru-thre",
            "latex_tables/tlru_thre_improvement_latex.tex",
        ),
    ]

    # Create the example-style tables
    print("Creating multi-index tables...")
    pivoted_tables = [
        create_example_style_table(source, percentiles)
        for source, _, _, _, _ in table_specs
    ]

    # Generate LaTeX tables
    print("Generating LaTeX tables...")
    for pivoted, (_, render, caption, label, out_path) in zip(pivoted_tables, table_specs):
        latex = render(pivoted, caption, label)
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(latex)

    print("All LaTeX tables created in the 'latex_tables' directory.")

# Run the table-generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 